import torch
import torch.nn as nn
from promptda.model.dpt import DPTHead
from promptda.model.config import model_configs
from promptda.utils.logger import Log
import os
from pathlib import Path
from huggingface_hub import hf_hub_download


class PromptDA(nn.Module):
    patch_size = 14  # patch size of the pretrained dinov2 model
    use_bn = False
    use_clstoken = False
    output_act = 'sigmoid'

    def __init__(self,
                 encoder='vitl',
                 ckpt_path='checkpoints/promptda_vitl.ckpt'):
        super().__init__()
        model_config = model_configs[encoder]

        self.encoder = encoder
        self.model_config = model_config
        module_path = Path(__file__) # From anywhere else: module_path = Path(inspect.getfile(PromptDA))
        package_base_dir = str(Path(*module_path.parts[:-2])) # extract path to PromptDA module, then resolve to repo base dir for dinov2 load
        self.pretrained = torch.hub.load(
            f'{package_base_dir}/torchhub/facebookresearch_dinov2_main',
            'dinov2_{:}14'.format(encoder),
            source='local',
            pretrained=False)
        dim = self.pretrained.blocks[0].attn.qkv.in_features
        self.depth_head = DPTHead(nclass=1,
                                  in_channels=dim,
                                  features=model_config['features'],
                                  out_channels=model_config['out_channels'],
                                  use_bn=self.use_bn,
                                  use_clstoken=self.use_clstoken,
                                  output_act=self.output_act)

        # mean and std of the pretrained dinov2 model
        self.register_buffer('_mean', torch.tensor(
            [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('_std', torch.tensor(
            [0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.load_checkpoint(ckpt_path)
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path = None, model_kwargs = None, **hf_kwargs):
        """
        Load a model from a checkpoint file.
        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint.
        """
        ckpt_path = None
        if Path(pretrained_model_name_or_path).exists():
            ckpt_path = pretrained_model_name_or_path
        else:
            cached_checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.ckpt",
                **hf_kwargs
            )
            ckpt_path = cached_checkpoint_path
        # model_config = checkpoint['model_config']
        # if model_kwargs is not None:
        #     model_config.update(model_kwargs)
        if model_kwargs is None:
            model_kwargs = {}
        model_kwargs.update({'ckpt_path': ckpt_path})
        model = cls(**model_kwargs)
        return model

    def load_checkpoint(self, ckpt_path):
        if os.path.exists(ckpt_path):
            Log.info(f'Loading checkpoint from {ckpt_path}')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            self.load_state_dict(
                {k[9:]: v for k, v in checkpoint['state_dict'].items()})
        else:
            Log.warn(f'Checkpoint {ckpt_path} not found')

    def forward(self, x, prompt_depth=None):
        assert prompt_depth is not None, 'prompt_depth is required'
        prompt_depth, min_val, max_val = self.normalize(prompt_depth)
        h, w = x.shape[-2:]
        features = self.pretrained.get_intermediate_layers(
            (x - self._mean) / self._std, self.model_config['layer_idxs'],
            return_class_token=True)
        patch_h, patch_w = h // self.patch_size, w // self.patch_size
        depth = self.depth_head(features, patch_h, patch_w, prompt_depth)
        depth = self.denormalize(depth, min_val, max_val)
        return depth

    @torch.no_grad()
    def predict(self,
                image: torch.Tensor,
                prompt_depth: torch.Tensor):
        return self.forward(image, prompt_depth)

    def normalize(self,
                  prompt_depth: torch.Tensor):
        B, C, H, W = prompt_depth.shape
        min_val = torch.quantile(
            prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None]
        max_val = torch.quantile(
            prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None]
        prompt_depth = (prompt_depth - min_val) / (max_val - min_val)
        return prompt_depth, min_val, max_val

    def denormalize(self,
                    depth: torch.Tensor,
                    min_val: torch.Tensor,
                    max_val: torch.Tensor):
        return depth * (max_val - min_val) + min_val
